
# G-formula built from models of Y, M, and L

# Build a design matrix with treatment A, M and L set to the given values.
#
# `a`, `m`, `l` may be scalars (recycled over rows) or per-row vectors
# (e.g. the observed `dat$M`). Besides overwriting/creating columns A, M, L,
# the interaction columns AC, AM, AL, ACM and AML are (re)computed from the
# supplied values and the covariate column `dat$C`. Existing columns keep
# their position; newly created ones are appended in the order listed above.
# The result is `as.matrix(dat)` (assumed all-numeric so it can be used with
# %*% downstream; TODO confirm with callers).
process_data <- function(dat, a, m, l) {
  covariate <- dat$C

  derived <- list(
    A = a,
    M = m,
    L = l,
    AC = a * covariate,
    AM = a * m,
    AL = a * l,
    ACM = a * covariate * m,
    AML = a * m * l
  )

  for (column in names(derived)) {
    dat[[column]] <- derived[[column]]
  }

  as.matrix(dat)
}

# ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
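# G-formula:
# PSE: contrast of Y(a, m(a), l(a')) for a = 1 vs. a = 0, with a' = 0:
# = \sum_i \pi_i \sum_{m, l} \{ E[Y | l, m, a = 1, c_i] p(m | a = 1, c_i)
#     - E[Y | l, m, a = 0, c_i] p(m | a = 0, c_i) \} p(l | a = 0, m, c_i),
# where \pi_i = px[i] is the weight of observation i and m, l range over {0, 1}.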
# ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++

# Compute the path-specific effect (PSE) by the g-formula:
#
#   effect = sum_i px[i] * sum_{m, l in {0, 1}} {
#              E[Y | l, m, A = 1, c_i] * p(m | A = 1, c_i)
#            - E[Y | l, m, A = 0, c_i] * p(m | A = 0, c_i)
#            } * p(l | A = 0, m, c_i)
#
# Arguments:
#   dat  data.frame with columns C, M, L (read via `$`) plus every column
#        named in the coefficient vectors; process_data() sets A, M, L and
#        the interaction columns AC, AM, AL, ACM, AML.
#   beta list of named numeric vectors: `beta_y` (linear outcome model),
#        `beta_l` (logistic model for L, evaluated with A and M fixed) and
#        `beta_m` (logistic model for M, evaluated with A fixed). Names
#        select the design-matrix columns; an unknown name is an error.
#   px   per-observation weights (presumably length nrow(dat) and summing
#        to 1 -- confirm with caller).
#   opt  list; `opt$reparam = TRUE` selects the reparameterized outcome
#        model whose last two `beta_y` entries are the offsets w0 and wa.
#        Any other value (FALSE, NULL) selects the plain linear model.
#
# Returns a numeric scalar. Under `reparam = TRUE` and `sum(px) == 1` the
# result equals `wa` by construction.
compute_effect <- function(dat, beta, px, opt) {
  reparam <- isTRUE(opt$reparam)

  beta_y <- beta$beta_y
  beta_l <- beta$beta_l
  beta_m <- beta$beta_m

  # x[, names(coef)] %*% coef. Columns are looked up in the processed design
  # matrix, so interaction columns created by process_data() are found even
  # when `dat` itself lacks them. drop = FALSE keeps one column a matrix.
  lin_pred <- function(x, coef) {
    absent <- setdiff(names(coef), colnames(x))
    if (length(absent) > 0) {
      stop("No design column for coefficient(s): ",
           paste(absent, collapse = ", "), call. = FALSE)
    }
    x[, names(coef), drop = FALSE] %*% coef
  }
  expit <- function(eta) 1 / (1 + exp(-eta))

  # Design matrices with A, M and L all fixed (outcome model).
  dat_a0m0l0 <- process_data(dat, a = 0, m = 0, l = 0)
  dat_a0m0l1 <- process_data(dat, a = 0, m = 0, l = 1)
  dat_a0m1l0 <- process_data(dat, a = 0, m = 1, l = 0)
  dat_a0m1l1 <- process_data(dat, a = 0, m = 1, l = 1)
  dat_a1m0l0 <- process_data(dat, a = 1, m = 0, l = 0)
  dat_a1m0l1 <- process_data(dat, a = 1, m = 0, l = 1)
  dat_a1m1l0 <- process_data(dat, a = 1, m = 1, l = 0)
  dat_a1m1l1 <- process_data(dat, a = 1, m = 1, l = 1)

  # A = 0 and M fixed, observed L kept (model for L).
  dat_a0m0 <- process_data(dat, a = 0, m = 0, l = dat$L)
  dat_a0m1 <- process_data(dat, a = 0, m = 1, l = dat$L)

  # A fixed, observed M and L kept (model for M).
  dat_a0 <- process_data(dat, a = 0, m = dat$M, l = dat$L)
  dat_a1 <- process_data(dat, a = 1, m = dat$M, l = dat$L)

  # p(L | M = m, A = 0, C): only the A = 0 arm enters the effect.
  p_l1a0m0 <- expit(lin_pred(dat_a0m0, beta_l))
  p_l0a0m0 <- 1 - p_l1a0m0
  p_l1a0m1 <- expit(lin_pred(dat_a0m1, beta_l))
  p_l0a0m1 <- 1 - p_l1a0m1

  # p(M | A = a, C)
  p_m1a0 <- expit(lin_pred(dat_a0, beta_m))
  p_m0a0 <- 1 - p_m1a0
  p_m1a1 <- expit(lin_pred(dat_a1, beta_m))
  p_m0a1 <- 1 - p_m1a1

  if (!reparam) {
    # Plain linear outcome model: E[Y | A, M, L, C] = X %*% beta_y.
    y_a0m0l0 <- lin_pred(dat_a0m0l0, beta_y)
    y_a0m0l1 <- lin_pred(dat_a0m0l1, beta_y)
    y_a0m1l0 <- lin_pred(dat_a0m1l0, beta_y)
    y_a0m1l1 <- lin_pred(dat_a0m1l1, beta_y)
    y_a1m0l0 <- lin_pred(dat_a1m0l0, beta_y)
    y_a1m0l1 <- lin_pred(dat_a1m0l1, beta_y)
    y_a1m1l0 <- lin_pred(dat_a1m1l0, beta_y)
    y_a1m1l1 <- lin_pred(dat_a1m1l1, beta_y)
  } else {
    # Reparameterized outcome model:
    #   E[Y | a, m, l, C] = f(a, m, l, C) - sum(px * F_a) + w0 + wa * a,
    # where F_a averages f(a, M, L, C) over p(m | A = a, C) *
    # p(l | A = 0, m, C) -- the same weights the effect applies to arm a.
    p <- length(beta_y)
    if (p < 2) {
      stop("beta_y must end with the two offsets w0 and wa.", call. = FALSE)
    }
    beta_f <- beta_y[seq_len(p - 2)]
    w0 <- beta_y[[p - 1]]
    wa <- beta_y[[p]]

    f_l0m0a0 <- lin_pred(dat_a0m0l0, beta_f)
    f_l0m0a1 <- lin_pred(dat_a1m0l0, beta_f)
    f_l0m1a0 <- lin_pred(dat_a0m1l0, beta_f)
    f_l0m1a1 <- lin_pred(dat_a1m1l0, beta_f)
    f_l1m0a0 <- lin_pred(dat_a0m0l1, beta_f)
    f_l1m0a1 <- lin_pred(dat_a1m0l1, beta_f)
    f_l1m1a0 <- lin_pred(dat_a0m1l1, beta_f)
    f_l1m1a1 <- lin_pred(dat_a1m1l1, beta_f)

    f_A1 <- (f_l0m0a1 * p_l0a0m0 + f_l1m0a1 * p_l1a0m0) * p_m0a1 +
      (f_l0m1a1 * p_l0a0m1 + f_l1m1a1 * p_l1a0m1) * p_m1a1
    # The A = 0 arm is centred under its own mediator law p(m | A = 0, C),
    # matching the A = 0 term of the effect, so that term collapses to w0.
    f_A0 <- (f_l0m0a0 * p_l0a0m0 + f_l1m0a0 * p_l1a0m0) * p_m0a0 +
      (f_l0m1a0 * p_l0a0m1 + f_l1m1a0 * p_l1a0m1) * p_m1a0

    shift_a1 <- w0 + wa - sum(px * f_A1)
    shift_a0 <- w0 - sum(px * f_A0)

    # E[Y | A = a, M = m, L = l, C]
    y_a1m0l0 <- f_l0m0a1 + shift_a1
    y_a0m0l0 <- f_l0m0a0 + shift_a0
    y_a1m0l1 <- f_l1m0a1 + shift_a1
    y_a0m0l1 <- f_l1m0a0 + shift_a0
    y_a1m1l0 <- f_l0m1a1 + shift_a1
    y_a0m1l0 <- f_l0m1a0 + shift_a0
    y_a1m1l1 <- f_l1m1a1 + shift_a1
    y_a0m1l1 <- f_l1m1a0 + shift_a0
  }

  effect <- sum(px * (
    (y_a1m0l0 * p_m0a1 - y_a0m0l0 * p_m0a0) * p_l0a0m0 +
      (y_a1m0l1 * p_m0a1 - y_a0m0l1 * p_m0a0) * p_l1a0m0 +
      (y_a1m1l0 * p_m1a1 - y_a0m1l0 * p_m1a0) * p_l0a0m1 +
      (y_a1m1l1 * p_m1a1 - y_a0m1l1 * p_m1a0) * p_l1a0m1
  ))

  effect
}


